{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://code.aliyun.com/qhduan/zh-bert/raw/0fb1d96ec2133fe25e66bee12fe387cbe1e52938/vocab.txt\n",
    "# !pip install tokenizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "import functools\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_addons as tfa\n",
    "from tokenizers import BertWordPieceTokenizer\n",
    "\n",
    "from model import GPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_learning_rate(learning_rate=6e-4,\n",
    "                      warmup_steps=20_0000,\n",
    "                      decay_steps=200_0000,\n",
    "                      alpha=0.0):\n",
    "    def decayed_learning_rate(step=1):\n",
    "        if step <= warmup_steps:\n",
    "            mult = step / float(warmup_steps)\n",
    "        else:\n",
    "            progress = (step - warmup_steps) / (decay_steps - warmup_steps)\n",
    "            mult = 0.5 * (1 + math.cos(math.pi * progress))\n",
    "            mult = max(0.1, mult)\n",
    "        return learning_rate * mult\n",
    "    return decayed_learning_rate\n",
    "\n",
    "\n",
    "def data_generator(path, batch_size=4):\n",
    "    batch = []\n",
    "    with open(path) as fp:\n",
    "        for line in fp:\n",
    "            line = line.strip()\n",
    "            if len(line) <= 0:\n",
    "                continue\n",
    "            x = [\n",
    "                tokenizer.token_to_id(x)\n",
    "                for x in line.split('\\t')]\n",
    "            batch.append(x)\n",
    "            if len(batch) >= batch_size:\n",
    "                batch = tf.ragged.constant(batch)\n",
    "                batch = batch.to_tensor()\n",
    "                yield batch[:, :-1], batch[:, 1:]\n",
    "                batch = []\n",
    "    if len(batch) > 0:\n",
    "        batch = tf.ragged.constant(batch)\n",
    "        batch = batch.to_tensor()\n",
    "        yield batch[:, :-1], batch[:, 1:]\n",
    "        batch = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertWordPieceTokenizer('vocab.txt')\n",
    "\n",
    "gpt = GPT(vocab_size=tokenizer.get_vocab_size(),\n",
    "          layer_size=12,\n",
    "          block_size=1024,\n",
    "          embedding_size=768,\n",
    "          num_attention_heads=12,\n",
    "          embedding_dropout=0.1,\n",
    "          attention_dropout=0.1,\n",
    "          residual_dropout=0.1)\n",
    "\n",
    "gpt.compile(\n",
    "    optimizer=tfa.optimizers.AdamW(\n",
    "        weight_decay=0.1,\n",
    "        learning_rate=get_learning_rate(),\n",
    "        beta_1=0.9,\n",
    "        beta_2=0.95,\n",
    "        epsilon=1e-8,\n",
    "        name='AdamW',\n",
    "        clipnorm=1.0\n",
    "    )\n",
    ")\n",
    "\n",
    "gpt._set_inputs(tf.keras.layers.Input((None,), dtype=tf.int32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.from_generator(\n",
    "    functools.partial(\n",
    "        data_generator,\n",
    "        path='/home/qhduan/DATA10T/seagate/DATASETS/bd-chat/chat.txt'),\n",
    "    output_types=(tf.int32, tf.int32),\n",
    "    output_shapes=(tf.TensorShape([None, None]), tf.TensorShape([None, None]))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt.fit(dataset, epochs=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpt.save('/tmp/gpt3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gpt_load = tf.keras.models.load_model('/tmp/gpt3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
